Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX] Fix AveragePool attributes support #3235

Merged
merged 1 commit into from
Jun 12, 2024
Merged

Conversation

AmosLewis
Copy link
Collaborator

@AmosLewis AmosLewis commented Apr 25, 2024

@AmosLewis
Copy link
Collaborator Author

AmosLewis commented Jun 3, 2024

The current code has fix the count_include_pad test case and get the Inception_v4_vaiq_int8 passed. Next need to clean the code.

python run.py -c ../../torch-mlir/build/ -i ../../iree-build/ -f onnx --tests onnx/operators/AveragePool --cachedir cachedir --report --torchtolinalg

tests model-run onnx-import torch-mlir iree-compile inference
onnx/operators/AveragePool passed passed passed passed passed

python ./run.py --torchmlirbuild ../../torch-mlir/build --tolerance 0.001 0.001 --cachedir ./huggingface_cache --ireebuild ../../iree-build -f onnx -g models --mode onnx --report --tests onnx/models/Inception_v4_vaiq_int8 --torchtolinalg
Status report for run: test-run using mode:onnx todtype:default backend:llvm-cpu

tests model-run onnx-import torch-mlir iree-compile inference
onnx/models/Inception_v4_vaiq_int8 passed passed passed passed failed

@AmosLewis AmosLewis force-pushed the avgpool2d branch 3 times, most recently from 76c07ee to 89ee8b7 Compare June 4, 2024 16:35
@AmosLewis
Copy link
Collaborator Author

AmosLewis commented Jun 4, 2024

class AvgPool2dCountIncludePadFalseStaticModule(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.ap2d = torch.nn.AvgPool2d(
            kernel_size=[3, 3],
            stride=[1, 1],
            padding=[1, 1],
            ceil_mode=False,
            count_include_pad=False,
            divisor_override=None,
        )

    @export
    @annotate_args(
        [
            None,
            ([32, 384, 25, 25], torch.float32, True),
        ]
    )
    def forward(self, x):
        return self.ap2d(x)


@register_test_case(module_factory=lambda: AvgPool2dCountIncludePadFalseStaticModule())
def AvgPool2dCountIncludePadFalseStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(32, 384, 25, 25, low=-1))
python -m e2e_testing.main --config=onnx --filter AvgPool2dFloatStaticModule  -v                                                      
TORCH_VERSION_FOR_COMPARISON = 2.4.0.dev20240505
Compiling AvgPool2dFloatStaticModule_basic...

====================
ONNX RAW IR
module {
  func.func @main_graph(%arg0: !torch.vtensor<[32,384,25,25],f32>) -> !torch.vtensor<[32,384,25,25],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
    %none = torch.constant.none
    %0 = torch.operator "onnx.AveragePool"(%arg0) {torch.onnx.ceil_mode = 0 : si64, torch.onnx.count_include_pad = 0 : si64, torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.pads = [1 : si64, 1 : si64, 1 : si64, 1 : si64], torch.onnx.strides = [1 : si64, 1 : si64]} : (!torch.vtensor<[32,384,25,25],f32>) -> !torch.vtensor<[32,384,25,25],f32> 
    return %0 : !torch.vtensor<[32,384,25,25],f32>
  }
}


====================
TorchFX IR
module {
  func.func @main_graph(%arg0: !torch.vtensor<[32,384,25,25],f32>) -> !torch.vtensor<[32,384,25,25],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
    %none = torch.constant.none
    %int3 = torch.constant.int 3
    %int3_0 = torch.constant.int 3
    %int1 = torch.constant.int 1
    %int1_1 = torch.constant.int 1
    %int1_2 = torch.constant.int 1
    %int1_3 = torch.constant.int 1
    %0 = torch.prim.ListConstruct %int3, %int3_0 : (!torch.int, !torch.int) -> !torch.list<int>
    %1 = torch.prim.ListConstruct %int1, %int1_1 : (!torch.int, !torch.int) -> !torch.list<int>
    %2 = torch.prim.ListConstruct %int1_2, %int1_3 : (!torch.int, !torch.int) -> !torch.list<int>
    %false = torch.constant.bool false
    %false_4 = torch.constant.bool false
    %none_5 = torch.constant.none
    %3 = torch.aten.avg_pool2d %arg0, %0, %2, %1, %false, %false_4, %none_5 : !torch.vtensor<[32,384,25,25],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[32,384,25,25],f32>
    return %3 : !torch.vtensor<[32,384,25,25],f32>
  }
}


====================
Torch Backend IR
module {
  func.func @main_graph(%arg0: !torch.vtensor<[32,384,25,25],f32>) -> !torch.vtensor<[32,384,25,25],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
    %false = torch.constant.bool false
    %none = torch.constant.none
    %int3 = torch.constant.int 3
    %int1 = torch.constant.int 1
    %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
    %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
    %2 = torch.aten.avg_pool2d %arg0, %0, %1, %1, %false, %false, %none : !torch.vtensor<[32,384,25,25],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[32,384,25,25],f32>
    return %2 : !torch.vtensor<[32,384,25,25],f32>
  }
}


====================
LINALG Backend IR
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
module {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @main_graph(%arg0: tensor<32x384x25x25xf32>) -> tensor<32x384x25x25xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %c1_i64 = arith.constant 1 : i64
    %c25_i64 = arith.constant 25 : i64
    %c0_i64 = arith.constant 0 : i64
    %c2_i64 = arith.constant 2 : i64
    %c26_i64 = arith.constant 26 : i64
    %padded = tensor.pad %arg0 low[0, 0, 1, 1] high[0, 0, 1, 1] {
    ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
      tensor.yield %cst : f32
    } : tensor<32x384x25x25xf32> to tensor<32x384x27x27xf32>
    %0 = tensor.empty() : tensor<32x384x25x25xf32>
    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<32x384x25x25xf32>) -> tensor<32x384x25x25xf32>
    %2 = tensor.empty() : tensor<3x3xf32>
    %3 = linalg.pooling_nchw_sum {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%padded, %2 : tensor<32x384x27x27xf32>, tensor<3x3xf32>) outs(%1 : tensor<32x384x25x25xf32>) -> tensor<32x384x25x25xf32>
    %4 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%3 : tensor<32x384x25x25xf32>) outs(%0 : tensor<32x384x25x25xf32>) {
    ^bb0(%in: f32, %out: f32):
      %5 = linalg.index 2 : index
      %6 = arith.index_cast %5 : index to i64
      %7 = linalg.index 3 : index
      %8 = arith.index_cast %7 : index to i64
      %9 = arith.subi %6, %c1_i64 : i64
      %10 = arith.subi %8, %c1_i64 : i64
      %11 = arith.addi %6, %c2_i64 : i64
      %12 = arith.minsi %11, %c26_i64 : i64
      %13 = arith.addi %8, %c2_i64 : i64
      %14 = arith.minsi %13, %c26_i64 : i64
      %15 = arith.maxsi %9, %c0_i64 : i64
      %16 = arith.maxsi %10, %c0_i64 : i64
      %17 = arith.minsi %12, %c25_i64 : i64
      %18 = arith.minsi %14, %c25_i64 : i64
      %19 = arith.subi %17, %15 : i64
      %20 = arith.subi %18, %16 : i64
      %21 = arith.muli %19, %20 : i64
      %22 = arith.sitofp %21 : i64 to f32
      %23 = arith.divf %in, %22 : f32
      linalg.yield %23 : f32
    } -> tensor<32x384x25x25xf32>
    return %4 : tensor<32x384x25x25xf32>
  }
}

Running AvgPool2dFloatStaticModule_basic...
PASS - "AvgPool2dFloatStaticModule_basic"

Summary:
    Passed: 1

@AmosLewis AmosLewis force-pushed the avgpool2d branch 6 times, most recently from 4379be2 to 7027ace Compare June 7, 2024 17:19
@AmosLewis
Copy link
Collaborator Author

AmosLewis commented Jun 10, 2024

Fix an AvgPool2dIntModule_basic crash for onnx. Get a new linalg e2e crash. Then Get a new tosa/stablehlo pass.

Torch Backend IR
module attributes {torch.debug_module_name = "AvgPool2dDivisorOverrideModule"} {
  func.func @forward(%arg0: !torch.vtensor<[4,4,20,20],f32>) -> !torch.vtensor<[4,4,11,7],f32> {
    %int4 = torch.constant.int 4
    %int8 = torch.constant.int 8
    %int2 = torch.constant.int 2
    %int3 = torch.constant.int 3
    %false = torch.constant.bool false
    %true = torch.constant.bool true
    %int22 = torch.constant.int 22
    %0 = torch.prim.ListConstruct %int4, %int8 : (!torch.int, !torch.int) -> !torch.list<int>
    %1 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
    %2 = torch.prim.ListConstruct %int2, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
    %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %true, %int22 : !torch.vtensor<[4,4,20,20],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.int -> !torch.vtensor<[4,4,11,7],f32>
    return %3 : !torch.vtensor<[4,4,11,7],f32>
  }
}

python: /home/chi/src/torch-mlir/lib/Conversion/Utils/Utils.cpp:327: Value mlir::torch::Torch::convertScalarToDtype(OpBuilder &, Location, Value, Type, std::optional<Type>, std::optional<Type>, std::optional<Value>): Assertion `isa<mlir::IntegerType>(scalarType)' failed.
[2]    1139312 IOT instruction (core dumped)  python -m e2e_testing.main --config=linalg -v -s

@AmosLewis AmosLewis force-pushed the avgpool2d branch 3 times, most recently from ff83f5b to f7f5a7f Compare June 10, 2024 20:59
- [ONNX] Fix padding attributes for onnx.AveragePool
- [Linalg] Add countIncludePad false support for AtenAvgPool1/2dOp
- [Linalg] Add an avg_pool2d countIncludePad False e2e tests
- [Linalg] Fix conflict with AtenAvgPool3dOp
- [Linalg] Fix e2e crash with AtenAvgPool1dOp
- [Linalg] Add dynamic dim support for AtenAvgPool2dOp
- [Linalg] Fix AvgPool2dDivisorOverrideModule crash
@rsuderman
Copy link
Contributor

You should check whether this maps well to the existing pooling linalg structured ops:
https://github.com/iree-org/llvm-project/blob/534590144f7c7ec34b8e5e95aba3e4f214b074eb/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml#L5051

We have specialized version that are computationally more efficient.

@AmosLewis
Copy link
Collaborator Author

AmosLewis commented Jun 10, 2024

You should check whether this maps well to the existing pooling linalg structured ops: https://github.com/iree-org/llvm-project/blob/534590144f7c7ec34b8e5e95aba3e4f214b074eb/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml#L5051

We have specialized version that are computationally more efficient.

We are using the linalg::PoolingNchwSumOp not linalg::PoolingNhwcSumOp you send. And I didn't change the sumpool part of code, so map shouldn't be an issue.

https://github.com/iree-org/llvm-project/blob/534590144f7c7ec34b8e5e95aba3e4f214b074eb/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml#L5134

@AmosLewis AmosLewis merged commit ae6f5e8 into llvm:main Jun 12, 2024
3 checks passed
@AmosLewis
Copy link
Collaborator Author

#3428 will need refer this patch

@nirvedhmeshram
Copy link
Collaborator

FYI looks like there is a regression with average pool that needs to be looked at
https://github.com/iree-org/iree/actions/runs/9523402840/job/26254914840?pr=17662#step:8:55

@AmosLewis
Copy link
Collaborator Author

FYI looks like there is a regression with average pool that needs to be looked at https://github.com/iree-org/iree/actions/runs/9523402840/job/26254914840?pr=17662#step:8:55

It's been fixed one my side.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants